import os
import sys
# Must run before the project-local imports below, which live in core_validation/.
sys.path.append("core_validation")
import numpy as np
from lstm_model_simulation import simulate_lstm_models
from RealPlantSimulation_modified import SimulateTheActualPlant_mz1, SimulateTheActualPlant_mz2, SimulateTheActualPlant_mz3
from Parameters_1D import *
import matplotlib as mpl
from matplotlib import pyplot as plt
# Spatial grid and time-stepping settings (pulled in via the star import above).
(totalDepth, axialNodes, totalNodes) = spatialVariables_1D_act()
(samplingTime, samplingTimeInternal,
 internalTimeSteps, totTimeSteps) = temporalVariable()

# Global matplotlib styling for every figure produced by this script:
# bold 15-size monospace text and axes line width 3.
font = dict(family='monospace', weight='bold', size=15)
mpl.rc('font', **font)
mpl.rc('axes', linewidth=3)



# Number of leading reference root-zone values used to seed each LSTM.
sequenceLength = 5

# Per-feature min/max bounds used for min-max scaling of the LSTM inputs,
# one min file and one max file per zone (loaded in this exact order).
(data_min_mz1, data_max_mz1,
 data_min_mz2, data_max_mz2,
 data_min_mz3, data_max_mz3) = (
    np.loadtxt(bounds_path) for bounds_path in (
        "./core_validation/Weights/data_min_mz1_vrd_lai.txt",
        "./core_validation/Weights/data_max_mz1_vrd_lai.txt",
        "./core_validation/Weights/data_min_mz2_vrd_lai.txt",
        "./core_validation/Weights/data_max_mz2_vrd_lai.txt",
        "./core_validation/Weights/data_min_mz3_vrd_lai.txt",
        "./core_validation/Weights/data_max_mz3_vrd_lai.txt",
    )
)

# Uniform initial state over all grid nodes for each zone
# (presumably pressure head -- confirm against SimulateTheActualPlant_*).
x0_mz1, x0_mz2, x0_mz3 = (
    initial_value * np.ones(totalNodes) for initial_value in (-9.86, -6.42, -8.65)
)

# Rooting depth used for every sampled episode; read by the samplers below.
current_root_depth = 0.50
def ObtainSimulationMetrics_mz1(root_depth=None):
    """Draw one random irrigation episode for zone 1.

    The episode length is drawn from 3..7 time steps. A single irrigation
    rate is drawn for the episode, together with a reference-ET value and a
    crop coefficient for every step of the episode.

    Parameters
    ----------
    root_depth : float, optional
        Rooting depth repeated over the episode. When omitted, the
        module-level ``current_root_depth`` is used (looked up at call time).

    Returns
    -------
    tuple
        ``(sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth)``:
        ``sim_period`` is an int in [3, 7], ``irrig_rate`` a float in
        [-6.0e-07, -4.5e-08), and the other three are arrays of length
        ``sim_period``.
    """
    if root_depth is None:
        root_depth = current_root_depth
    # Every episode length previously had its own, identical branch. The draws
    # below are the same calls in the same order, so seeded runs are unchanged.
    sim_period = np.random.randint(low=3, high=8)
    irrig_rate = np.random.uniform(-6.0e-07, -4.5e-08)
    ref_evap = np.random.uniform(1.2e-08, 1.04e-07, sim_period)
    crop_coeff = np.random.uniform(0.4, 1.02, sim_period)
    rooting_depth = root_depth * np.ones(sim_period)
    return sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth
    
def ObtainSimulationMetrics_mz2(root_depth=None):
    """Draw one random irrigation episode for zone 2.

    The episode length is drawn from 3..7 time steps. A single irrigation
    rate is drawn for the episode, together with a reference-ET value and a
    crop coefficient for every step of the episode.

    Parameters
    ----------
    root_depth : float, optional
        Rooting depth repeated over the episode. When omitted, the
        module-level ``current_root_depth`` is used (looked up at call time).

    Returns
    -------
    tuple
        ``(sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth)``:
        ``sim_period`` is an int in [3, 7], ``irrig_rate`` a float in
        [-6.9e-07, -5.0e-08), and the other three are arrays of length
        ``sim_period``.
    """
    if root_depth is None:
        root_depth = current_root_depth
    # Every episode length previously had its own, identical branch. The draws
    # below are the same calls in the same order, so seeded runs are unchanged.
    sim_period = np.random.randint(low=3, high=8)
    irrig_rate = np.random.uniform(-6.9e-07, -5.0e-08)
    ref_evap = np.random.uniform(1.2e-08, 1.04e-07, sim_period)
    crop_coeff = np.random.uniform(0.4, 1.02, sim_period)
    rooting_depth = root_depth * np.ones(sim_period)
    return sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth


def ObtainSimulationMetrics_mz3(root_depth=None):
    """Draw one random irrigation episode for zone 3.

    The episode length is drawn from 3..7 time steps. A single irrigation
    rate is drawn for the episode, together with a reference-ET value and a
    crop coefficient for every step of the episode.

    Parameters
    ----------
    root_depth : float, optional
        Rooting depth repeated over the episode. When omitted, the
        module-level ``current_root_depth`` is used (looked up at call time).

    Returns
    -------
    tuple
        ``(sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth)``:
        ``sim_period`` is an int in [3, 7], ``irrig_rate`` a float in
        [-7.2e-07, -5.8e-08), and the other three are arrays of length
        ``sim_period``.
    """
    if root_depth is None:
        root_depth = current_root_depth
    # Every episode length previously had its own, identical branch. The draws
    # below are the same calls in the same order, so seeded runs are unchanged.
    sim_period = np.random.randint(low=3, high=8)
    irrig_rate = np.random.uniform(-7.2e-07, -5.8e-08)
    ref_evap = np.random.uniform(1.2e-08, 1.04e-07, sim_period)
    crop_coeff = np.random.uniform(0.4, 1.02, sim_period)
    rooting_depth = root_depth * np.ones(sim_period)
    return sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth
    


def _sample_schedule(sample_episode, n_episodes):
    """Concatenate ``n_episodes`` random episodes drawn with ``sample_episode``.

    ``sample_episode`` must return
    ``(sim_period, irrig_rate, ref_evap, crop_coeff, rooting_depth)``.
    Each episode's irrigation rate is applied on its first step only; the
    remaining steps of that episode get 0. The sampled crop coefficients are
    discarded because the script reads crop coefficients from a file instead.

    Returns ``(irrig_rates, ref_evaps, rooting_depths, sim_periods)`` as lists.
    """
    irrig_rates, ref_evaps, rooting_depths, sim_periods = [], [], [], []
    for i in range(n_episodes):
        print(i)  # progress output
        sim_period, irrig_rate, ref_evap, _crop_coeff, rooting_depth = sample_episode()
        irrig_rates.append(irrig_rate)
        irrig_rates.extend([0] * (sim_period - 1))
        rooting_depths.extend(rooting_depth)
        ref_evaps.extend(ref_evap)
        sim_periods.append(sim_period)
    return irrig_rates, ref_evaps, rooting_depths, sim_periods


# Number of random episodes concatenated per zone.
period = 10

# Zones are sampled one after another (mz1, mz2, mz3) so the global NumPy RNG
# is consumed in the same order as before.
irrig_rates_mz1, ref_evaps_mz1, rooting_depth_train_mz1, sim_periods_mz1 = \
    _sample_schedule(ObtainSimulationMetrics_mz1, period)
irrig_rates_mz2, ref_evaps_mz2, rooting_depth_train_mz2, sim_periods_mz2 = \
    _sample_schedule(ObtainSimulationMetrics_mz2, period)
irrig_rates_mz3, ref_evaps_mz3, rooting_depth_train_mz3, sim_periods_mz3 = \
    _sample_schedule(ObtainSimulationMetrics_mz3, period)


# Crop coefficients and LAI factors are read from files (the crop coefficients
# drawn by the episode samplers are not used). Each zone takes the leading
# entries, one per simulated time step.
_crop_coeffs = np.loadtxt('./core_validation/crop_coeff_n.txt')
_lai_factors = np.loadtxt("./core_validation/LAI_factors.txt")


def _leading_inputs(n_steps, min_crop_coeff=0.20):
    """Return the first ``n_steps`` crop coefficients and LAI factors.

    Crop coefficients below ``min_crop_coeff`` are raised to it. They are
    returned as a new array. The old code clipped slice views in place,
    which wrote through to ``_crop_coeffs``, a buffer every zone slices.
    The LAI factors are returned as a slice of ``_lai_factors``, as before.

    Raises
    ------
    ValueError
        If either file holds fewer than ``n_steps`` values. Previously the
        slice was silently shorter than the simulation horizon.
    """
    if len(_crop_coeffs) < n_steps or len(_lai_factors) < n_steps:
        raise ValueError(
            f"need {n_steps} crop-coefficient/LAI values, but the files provide "
            f"{len(_crop_coeffs)} and {len(_lai_factors)}"
        )
    crop_coeffs = np.maximum(_crop_coeffs[:n_steps], min_crop_coeff)
    return crop_coeffs, _lai_factors[:n_steps]


crop_coeffs_mz1, lai_factors_mz1 = _leading_inputs(len(ref_evaps_mz1))
crop_coeffs_mz2, lai_factors_mz2 = _leading_inputs(len(ref_evaps_mz2))
crop_coeffs_mz3, lai_factors_mz3 = _leading_inputs(len(ref_evaps_mz3))



irrigation_rate_factor = 1.0


def _minmax_scale(values, data_min, data_max, column):
    """Min-max scale ``values`` with feature ``column`` of the given bounds.

    Column layout as used in this script: 0 root-zone state, 1 irrigation
    rate, 2 crop coefficient, 3 reference ET, 4 rooting depth, 5 LAI factor.
    """
    lower = data_min[column]
    return (values - lower) / (data_max[column] - lower)


# Reference trajectories from the Richards-equation simulator, one per zone.
headArray_mz1, head_rz_mz1_RE = SimulateTheActualPlant_mz1(
    x0_mz1,
    irrigation_rate_factor * np.array(irrig_rates_mz1),
    np.array(ref_evaps_mz1),
    crop_coeffs_mz1,
    len(irrig_rates_mz1),
    np.array(rooting_depth_train_mz1),
    lai_factors_mz1,
)
headArray_mz2, head_rz_mz2_RE = SimulateTheActualPlant_mz2(
    x0_mz2,
    irrigation_rate_factor * np.array(irrig_rates_mz2),
    np.array(ref_evaps_mz2),
    crop_coeffs_mz2,
    len(irrig_rates_mz2),
    np.array(rooting_depth_train_mz2),
    lai_factors_mz2,
)
headArray_mz3, head_rz_mz3_RE = SimulateTheActualPlant_mz3(
    x0_mz3,
    irrigation_rate_factor * np.array(irrig_rates_mz3),
    np.array(ref_evaps_mz3),
    crop_coeffs_mz3,
    len(irrig_rates_mz3),
    np.array(rooting_depth_train_mz3),
    lai_factors_mz3,
)

# Seed each LSTM with the first `sequenceLength` reference root-zone values.
x0_scaled_mz1 = _minmax_scale(head_rz_mz1_RE[0:sequenceLength], data_min_mz1, data_max_mz1, 0)
x0_scaled_mz2 = _minmax_scale(head_rz_mz2_RE[0:sequenceLength], data_min_mz2, data_max_mz2, 0)
x0_scaled_mz3 = _minmax_scale(head_rz_mz3_RE[0:sequenceLength], data_min_mz3, data_max_mz3, 0)

# Scaled exogenous input sequences, per zone.
u_scaled_mz1 = _minmax_scale(irrigation_rate_factor * np.array(irrig_rates_mz1), data_min_mz1, data_max_mz1, 1)
kc_scaled_mz1 = _minmax_scale(np.array(crop_coeffs_mz1), data_min_mz1, data_max_mz1, 2)
et_scaled_mz1 = _minmax_scale(np.array(ref_evaps_mz1), data_min_mz1, data_max_mz1, 3)
rd_scaled_mz1 = _minmax_scale(np.array(rooting_depth_train_mz1), data_min_mz1, data_max_mz1, 4)
lai_scaled_mz1 = _minmax_scale(lai_factors_mz1, data_min_mz1, data_max_mz1, 5)

u_scaled_mz2 = _minmax_scale(irrigation_rate_factor * np.array(irrig_rates_mz2), data_min_mz2, data_max_mz2, 1)
kc_scaled_mz2 = _minmax_scale(np.array(crop_coeffs_mz2), data_min_mz2, data_max_mz2, 2)
et_scaled_mz2 = _minmax_scale(np.array(ref_evaps_mz2), data_min_mz2, data_max_mz2, 3)
rd_scaled_mz2 = _minmax_scale(np.array(rooting_depth_train_mz2), data_min_mz2, data_max_mz2, 4)
lai_scaled_mz2 = _minmax_scale(lai_factors_mz2, data_min_mz2, data_max_mz2, 5)

u_scaled_mz3 = _minmax_scale(irrigation_rate_factor * np.array(irrig_rates_mz3), data_min_mz3, data_max_mz3, 1)
kc_scaled_mz3 = _minmax_scale(np.array(crop_coeffs_mz3), data_min_mz3, data_max_mz3, 2)
et_scaled_mz3 = _minmax_scale(np.array(ref_evaps_mz3), data_min_mz3, data_max_mz3, 3)
rd_scaled_mz3 = _minmax_scale(np.array(rooting_depth_train_mz3), data_min_mz3, data_max_mz3, 4)
lai_scaled_mz3 = _minmax_scale(lai_factors_mz3, data_min_mz3, data_max_mz3, 5)

# simulate_lstm_models takes one set of exogenous inputs and returns a
# trajectory for each of the three zone models. It is run once per zone's
# input schedule; only the matching output is used later (_a -> mz1,
# _b -> mz2, _c -> mz3).
head_rz_mz1_LSTM_a, head_rz_mz2_LSTM_a, head_rz_mz3_LSTM_a = simulate_lstm_models(
    x0_scaled_mz1, x0_scaled_mz2, x0_scaled_mz3,
    u_scaled_mz1, kc_scaled_mz1, et_scaled_mz1, rd_scaled_mz1, lai_scaled_mz1,
    len(irrig_rates_mz1),
)
head_rz_mz1_LSTM_b, head_rz_mz2_LSTM_b, head_rz_mz3_LSTM_b = simulate_lstm_models(
    x0_scaled_mz1, x0_scaled_mz2, x0_scaled_mz3,
    u_scaled_mz2, kc_scaled_mz2, et_scaled_mz2, rd_scaled_mz2, lai_scaled_mz2,
    len(irrig_rates_mz2),
)
head_rz_mz1_LSTM_c, head_rz_mz2_LSTM_c, head_rz_mz3_LSTM_c = simulate_lstm_models(
    x0_scaled_mz1, x0_scaled_mz2, x0_scaled_mz3,
    u_scaled_mz3, kc_scaled_mz3, et_scaled_mz3, rd_scaled_mz3, lai_scaled_mz3,
    len(irrig_rates_mz3),
)


# Compare the first 30 steps of the Richards-equation reference (red, solid)
# with the LSTM prediction (blue, dash-dot), one subplot per zone.
time = np.arange(1, 31, 1)
fig, axs = plt.subplots(3, figsize=(8, 12))
_label_style = dict(family='monospace', size=18, weight='bold')
for ax, zone, actual, predicted in zip(
        axs,
        (1, 2, 3),
        (head_rz_mz1_RE, head_rz_mz2_RE, head_rz_mz3_RE),
        (head_rz_mz1_LSTM_a, head_rz_mz2_LSTM_b, head_rz_mz3_LSTM_c)):
    ax.plot(time, actual[:len(time)], color='r', linestyle='-', linewidth=3)
    ax.plot(time, predicted[:len(time)], color='b', linestyle='-.', linewidth=3)
    ax.legend(['Actual', 'Predicted'], ncol=2)
    ax.set_title(f'LSTM model validation for MZ$_{zone}$', **_label_style)
    # NOTE(review): the label says volumetric water content while the plotted
    # series are named head_rz_*; confirm which quantity is plotted.
    ax.set_ylabel(r'$\theta^{RZ}$ (m$^3$/m$^3$)', **_label_style)
    ax.grid(linewidth=0.5, linestyle='-.')
    ax.set_xlim([1, 30])
# Only the bottom subplot carries the shared time-axis label.
axs[-1].set_xlabel('Time (days)', **_label_style)
fig.tight_layout()
# Create the output directory so a missing ./results does not make savefig
# fail after all the simulation work. The file name follows the configured
# rooting depth ("0.50" for the current 0.50).
os.makedirs('./results', exist_ok=True)
fig.savefig(f'./results/lstm_perform_{current_root_depth:.2f}m.pdf')

def _rmse_and_r2(actual, predicted):
    """Compare ``predicted`` with ``actual`` point by point.

    Returns ``(err, rmse, r2)``:
    - ``err = actual - predicted``;
    - ``rmse`` is the root-mean-square of ``err``;
    - ``r2 = 1 - SSE/TSS`` is the coefficient of determination.

    When ``actual`` is constant, TSS is zero and R^2 is undefined; NaN is
    returned instead of dividing by zero.
    """
    err = actual - predicted
    sse = np.sum(err ** 2)
    rmse = (sse / len(err)) ** 0.5
    tss = np.sum((actual - np.mean(actual)) ** 2)
    r2 = 1 - sse / tss if tss > 0 else float('nan')
    return err, rmse, r2


# Same 30-step horizon as the validation plot.
_EVAL_STEPS = 30
err_mz1, rmse_mz1, r2_mz1 = _rmse_and_r2(head_rz_mz1_RE[:_EVAL_STEPS], head_rz_mz1_LSTM_a[:_EVAL_STEPS])
err_mz2, rmse_mz2, r2_mz2 = _rmse_and_r2(head_rz_mz2_RE[:_EVAL_STEPS], head_rz_mz2_LSTM_b[:_EVAL_STEPS])
err_mz3, rmse_mz3, r2_mz3 = _rmse_and_r2(head_rz_mz3_RE[:_EVAL_STEPS], head_rz_mz3_LSTM_c[:_EVAL_STEPS])

# The metrics were previously computed and then discarded; report them.
for _zone, _rmse, _r2 in ((1, rmse_mz1, r2_mz1), (2, rmse_mz2, r2_mz2), (3, rmse_mz3, r2_mz3)):
    print(f"MZ{_zone}: RMSE = {_rmse:.4g}, R^2 = {_r2:.4f}")